"""
Phase 6: Robustness Tests
===========================
1. Goodhart-Pradhan reversal test (pre/post-2020 inflation sign)
2. Juselius-Takats U-shape (youth_dep + old_dep → inflation)
3. Japan exclusion
4. CCA exclusion
5. Subsample tests (OECD, Asia, Europe)
6. Alternative indices (variance-weighted, PCA-based, and rolling when available)

Input:  japanification/data/processed/japan_panel_indexed.csv
Output: japanification/output/tables/phase6_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Hard-coded absolute project root; every other location below is derived
# from it.  NOTE(review): looks like a WSL mount of a Windows drive, so the
# script is presumably only runnable on that machine layout -- confirm.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
# Directory the indexed panel CSV is read from.
PROCESSED_DIR = JAPAN_DIR / "data" / "processed"
# Directory the phase-6 result tables are written to.
TABLE_DIR = JAPAN_DIR / "output" / "tables"

# The multilateral/src directory is pushed to the front of the import path
# before importing `model` (presumably where model.py lives -- confirm).
# This import therefore has to come after the sys.path change instead of
# sitting with the other imports at the top of the file.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS specification, report it, and tabulate it.

    A fresh ``PanelGLS`` is fitted on ``(y, X, entity_ids, time_ids)``.
    The label is printed, the model's ``summary`` is invoked with
    ``feature_names``, and the table returned by ``to_dataframe`` gets a
    ``'model'`` column holding ``label`` so tables from several
    specifications can be stacked and told apart.

    Returns:
        A ``(fitted_model, coefficient_table)`` tuple.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Label first, then the model's own summary.
    print(f"\n  {label}")
    estimator.summary(feature_names=feature_names)

    coef_table = estimator.to_dataframe(feature_names=feature_names)
    coef_table['model'] = label
    return estimator, coef_table


# Country groups (ISO3 codes) used for the exclusion and subsample tests.
# Each group is written as a whitespace-separated string and split into a
# plain list, preserving the listed order.

# Excluded as a block in the "CCA Exclusion" test
# (presumably Caucasus and Central Asia -- confirm).
CCA_COUNTRIES = "ARM AZE GEO KAZ KGZ TJK TKM UZB".split()

OECD_COUNTRIES = """
    AUS AUT BEL CAN CHL COL CRI CZE DNK
    EST FIN FRA DEU GRC HUN ISL IRL ISR
    ITA JPN KOR LVA LTU LUX MEX NLD NZL
    NOR POL PRT SVK SVN ESP SWE CHE TUR
    GBR USA
""".split()

ASIA_COUNTRIES = """
    JPN KOR CHN IND IDN THA VNM MYS PHL
    SGP HKG TWN BGD PAK LKA MMR KHM LAO
    MNG NPL
""".split()

EUROPE_COUNTRIES = """
    AUT BEL BGR HRV CYP CZE DNK EST FIN
    FRA DEU GRC HUN IRL ITA LVA LTU LUX
    MLT NLD NOR POL PRT ROU SVK SVN ESP
    SWE CHE GBR
""".split()


def _print_section(title):
    """Print a section banner: blank line, rule, indented title, rule."""
    print(f"\n{'=' * 70}")
    print(f"  {title}")
    print("=" * 70)


def _estimate_spec(frame, target, regressors, label, results,
                   min_obs=0, skip_name=None):
    """Fit one specification on the complete cases of ``frame``.

    Rows with a missing value in ``target`` or any of ``regressors`` are
    dropped.  If fewer than ``min_obs`` rows remain, the specification is
    reported as skipped and nothing is estimated; otherwise it is fitted via
    ``fit_and_report`` and its coefficient table is appended to ``results``.

    Args:
        frame: Panel rows to estimate on; must contain 'iso3' and 'year'.
        target: Name of the dependent-variable column.
        regressors: List of regressor column names.
        label: Specification label, passed through to ``fit_and_report``.
        results: List that collects coefficient tables (mutated in place).
        min_obs: Minimum number of complete rows needed to estimate.
        skip_name: Name used in the "insufficient obs" message; defaults
            to ``label``.

    Returns:
        True if the specification was estimated, False if it was skipped.
    """
    sample = frame.dropna(subset=[target] + regressors)
    if len(sample) < min_obs:
        print(f"  {skip_name or label}: insufficient obs ({len(sample)})")
        return False

    _, table = fit_and_report(
        sample[target].values, sample[regressors].values,
        sample['iso3'].values, sample['year'].values,
        regressors, label,
    )
    results.append(table)
    return True


def main():
    """Run every phase-6 robustness specification and save the results.

    Reads ``japan_panel_indexed.csv`` from ``PROCESSED_DIR``, estimates the
    reference baseline plus the robustness variants (pre/post-2020 inflation
    regressions, dependency-ratio regression, Japan and CCA exclusions,
    regional subsamples, alternative indices), prints a comparison of the
    ``Z_1`` coefficient across specifications, and writes two CSV tables to
    ``TABLE_DIR`` (created if it does not exist).

    Specifications whose sample falls below their observation threshold, or
    whose optional input columns are absent, are reported and skipped.
    """
    print("=" * 70)
    print("PHASE 6: Robustness Tests")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    dep_var = 'japan_index_2c'
    inflation_var = 'inflation_japan'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    base_vars = demo_vars + controls

    all_results = []

    # Reference baseline (no minimum-observation threshold).
    _estimate_spec(df, dep_var, base_vars, "Reference Baseline", all_results)

    # =================================================================
    # 1. Goodhart-Pradhan reversal test
    # =================================================================
    _print_section("1. Goodhart-Pradhan Reversal Test")

    # Z -> inflation, estimated separately before and from 2020 onwards.
    inf_vars = demo_vars + controls
    pre_2020 = df['year'] < 2020
    post_2020 = df['year'] >= 2020

    for period_label, period_mask in [('Pre-2020 (1990-2019)', pre_2020),
                                      ('Post-2020 (2020-2024)', post_2020)]:
        _estimate_spec(
            df[period_mask], inflation_var, inf_vars,
            f"Goodhart-Pradhan: Z→Inflation {period_label}",
            all_results, min_obs=50,
        )

    # Same split, additionally by income group.
    if 'gdp_pc_ppp' in df.columns:
        # NOTE(review): the median is taken over all country-year rows, so
        # the same country can fall in different groups in different years.
        median_gdp = df['gdp_pc_ppp'].median()
        income_groups = [
            ('AE (above median GDP)', df['gdp_pc_ppp'] >= median_gdp),
            ('EM (below median GDP)', df['gdp_pc_ppp'] < median_gdp),
        ]
        for inc_label, inc_mask in income_groups:
            for period_label, period_mask in [('Pre-2020', pre_2020),
                                              ('Post-2020', post_2020)]:
                _estimate_spec(
                    df[inc_mask & period_mask], inflation_var, inf_vars,
                    f"GP: {inc_label} {period_label}",
                    all_results, min_obs=50,
                )

    # =================================================================
    # 2. Juselius-Takats U-shape
    # =================================================================
    _print_section("2. Juselius-Takats U-Shape Test")

    if 'youth_dep' in df.columns and 'old_dep' in df.columns:
        jt_vars = ['youth_dep', 'old_dep', 'working_age_share'] + controls
        # Only regressors actually present in the panel are used.
        avail_jt = [v for v in jt_vars if v in df.columns]
        estimated = _estimate_spec(
            df, inflation_var, avail_jt,
            "Juselius-Takats: Dep Ratios → Inflation",
            all_results, min_obs=200,
        )
        if estimated:
            print("\n  Expected: youth_dep (+), old_dep (+), working_age_share (-)")
            print("  This gives U-shape: both young and old are inflationary")
    else:
        print("  Skipped: youth_dep and/or old_dep not in panel")

    # =================================================================
    # 3. Japan exclusion
    # =================================================================
    _print_section("3. Japan Exclusion")
    _estimate_spec(df[df['iso3'] != 'JPN'], dep_var, base_vars,
                   "Excluding Japan", all_results)

    # =================================================================
    # 4. CCA exclusion
    # =================================================================
    _print_section("4. CCA Exclusion")
    _estimate_spec(df[~df['iso3'].isin(CCA_COUNTRIES)], dep_var, base_vars,
                   "Excluding CCA", all_results)

    # =================================================================
    # 5. Subsample tests
    # =================================================================
    _print_section("5. Subsample Tests")

    for group_name, countries in [('OECD', OECD_COUNTRIES),
                                  ('Asia', ASIA_COUNTRIES),
                                  ('Europe', EUROPE_COUNTRIES)]:
        _estimate_spec(
            df[df['iso3'].isin(countries)], dep_var, base_vars,
            f"Subsample: {group_name}", all_results,
            min_obs=100, skip_name=group_name,
        )

    # =================================================================
    # 6. Alternative indices
    # =================================================================
    _print_section("6. Alternative Index Specifications")

    index_parts = ['z_growth', 'z_inflation']
    has_index_parts = all(col in df.columns for col in index_parts)

    # 6a. Variance-weighted index
    # Weight each component inversely by its variance (so low-variance
    # components get more weight — they're more informative signals).
    # Both components must be present: the index is built from both columns,
    # so substituting a placeholder variance for a missing one cannot work.
    if has_index_parts:
        g_var = df['z_growth'].var()
        i_var = df['z_inflation'].var()

        # A NaN variance compares False here and is skipped as well.
        if g_var > 0 and i_var > 0:
            inv_total = 1 / g_var + 1 / i_var
            w_g = (1 / g_var) / inv_total
            w_i = (1 / i_var) / inv_total
            df['japan_index_varwt'] = -(w_g * df['z_growth']
                                        + w_i * df['z_inflation'])
            _estimate_spec(df, 'japan_index_varwt', base_vars,
                           "Alt Index: Variance-Weighted",
                           all_results, min_obs=200)
        else:
            print("  Alt Index: Variance-Weighted: skipped "
                  "(component variance not positive)")
    else:
        print("  Alt Index: Variance-Weighted: skipped "
              "(z_growth and/or z_inflation not in panel)")

    # 6b. PCA-based index
    if has_index_parts:
        complete = df[index_parts].notna().all(axis=1)
        pca_data = df.loc[complete, index_parts]

        if len(pca_data) >= 200:
            cov = np.cov(pca_data.values.T)
            eigenvalues, eigenvectors = np.linalg.eigh(cov)

            # eigh returns eigenvalues in ascending order, so the first
            # principal component is the last column.
            pc1 = eigenvectors[:, -1]
            # Ensure higher PC1 = more Japanified (flip if needed):
            # we want low growth/inflation = high index.
            if pc1.sum() > 0:
                pc1 = -pc1

            df['japan_index_pca'] = np.nan
            df.loc[complete, 'japan_index_pca'] = pca_data.values @ pc1

            _estimate_spec(df, 'japan_index_pca', base_vars,
                           "Alt Index: PCA", all_results, min_obs=200)

            print(f"\n  PCA loadings: growth={pc1[0]:.3f}, inflation={pc1[1]:.3f}")
            print(f"  Variance explained by PC1: {eigenvalues[-1]/eigenvalues.sum()*100:.1f}%")
        else:
            print(f"  Alt Index: PCA: insufficient obs ({len(pca_data)})")
    else:
        print("  Alt Index: PCA: skipped "
              "(z_growth and/or z_inflation not in panel)")

    # 6c. Rolling index version
    if 'japan_index_rolling' in df.columns:
        _estimate_spec(df, 'japan_index_rolling', base_vars,
                       "Alt Index: Rolling 5yr MA", all_results, min_obs=200)

    # =================================================================
    # Summary comparison
    # =================================================================
    _print_section("ROBUSTNESS SUMMARY")

    results_df = pd.concat(all_results, ignore_index=True)

    # Extract Z_1 coefficient across all models for comparison
    z1_rows = results_df[results_df['variable'] == 'Z_1']
    z1_compare = (z1_rows[['model', 'coefficient', 'std_error', 'p_value']]
                  .sort_values('model'))
    print("\n  Z_1 coefficient across specifications:")
    print(z1_compare.to_string(index=False, float_format='%.4f'))

    # Save. to_csv does not create missing parent directories, so make sure
    # the output directory exists before writing.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(TABLE_DIR / "phase6_robustness_results.csv", index=False)
    z1_compare.to_csv(TABLE_DIR / "phase6_z1_comparison.csv", index=False)

    print(f"\n{'=' * 70}")
    print(f"Phase 6 complete. {results_df['model'].nunique()} specifications estimated.")
    print(f"Tables saved to {TABLE_DIR}")
    print("=" * 70)


# Run the robustness suite only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
